{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 自定义训练\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/middleclass/custom/train.ipynb)\n",
    "\n",
    "MindSpore提供了`model.train`接口来进行模型训练。使用方式可以参考[初级教程-初学入门](https://www.mindspore.cn/tutorial/zh-CN/master/quick_start.html)。此外，还可以使用`TrainOneStepCell`，该接口当前支持GPU、Ascend环境。\n",
    "\n",
    "作为高阶接口，`model.train`封装了`TrainOneStepCell`，可以直接利用设定好的网络、损失函数与优化器进行训练。用户也可以选择使用`TrainOneStepCell`实现更加灵活的训练，例如控制训练数据集、实现多输入多输出网络、或自定义训练过程。\n",
    "\n",
    "## TrainOneStepCell说明\n",
    "\n",
    "`TrainOneStepCell`中包含三种入参：\n",
    "\n",
    "- network (Cell)：参与训练的网络，当前仅接受单输出网络。\n",
    "\n",
    "- optimizer (Cell)：所使用的优化器。\n",
    "\n",
    "- sens (Number)：反向传播的缩放比例。\n",
    "\n",
    "下面使用`TrainOneStepCell`替换`model.train`，实现简单的线性网络训练过程。\n",
    "\n",
    "## TrainOneStepCell使用示例\n",
    "\n",
    "### 创建模型并生成数据\n",
    "\n",
    "> 本小节详细解释说明可参考[初级教程-初学入门](https://www.mindspore.cn/tutorial/zh-CN/master/quick_start.html)。\n",
    "\n",
    "定义网络LinearNet，内部有两层全连接层组成的网络， 包含5个入参和1个出参的神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from mindspore import Tensor\n",
    "import mindspore.nn as nn\n",
    "from mindspore.nn import Cell, Dense\n",
    "import mindspore.ops as ops\n",
    "import mindspore.dataset as ds\n",
    "from mindspore import ParameterTuple\n",
    "\n",
    "class LinearNet(Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.dense1 = Dense(5, 32)\n",
    "        self.dense2 = Dense(32, 1)\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.dense1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.dense2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "产生输入数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "np.random.seed(4)\n",
    "\n",
    "class DatasetGenerator:\n",
    "    def __init__(self):\n",
    "        self.data = np.random.randn(5, 5).astype(np.float32)\n",
    "        self.label = np.random.randn(5, 1).astype(np.int32)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.data[index], self.label[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "数据处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 对输入数据进行处理\n",
    "dataset_generator = DatasetGenerator()\n",
    "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=True)\n",
    "dataset = dataset.batch(32)\n",
    "\n",
    "# 实例化网络\n",
    "net = LinearNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 定义TrainOneStepCell\n",
    "\n",
    "在`TrainOneStepCell`中，可以实现对训练过程的个性化设定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "class TrainOneStepCell(nn.Cell):\n",
    "    def __init__(self, network, optimizer, sens=1.0):\n",
    "        \"\"\"参数初始化\"\"\"\n",
    "        super(TrainOneStepCell, self).__init__(auto_prefix=False)\n",
    "        self.network = network\n",
    "        # 使用tuple包装weight\n",
    "        self.weights = ParameterTuple(network.trainable_params())\n",
    "        self.optimizer = optimizer\n",
    "        # 定义梯度函数\n",
    "        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)\n",
    "        self.sens = sens\n",
    "\n",
    "    def construct(self, data, label):\n",
    "        \"\"\"构建训练过程\"\"\"\n",
    "        weights = self.weights\n",
    "        loss = self.network(data, label)\n",
    "        # 为反向传播设定系数\n",
    "        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)\n",
    "        grads = self.grad(self.network, weights)(data, label, sens)\n",
    "        return loss, self.optimizer(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 网络训练\n",
    "\n",
    "在使用`TrainOneStepCell`时，需要利用`WithLossCell`接口引入损失函数，共同完成训练过程。下面利用之前设定好的参数训练LeNet网络，并获取loss值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7998974\n",
      "0.79927444\n",
      "0.7986423\n",
      "0.7979911\n",
      "0.79732\n",
      "... ...\n",
      "0.042837422\n",
      "0.041227795\n",
      "0.039638687\n",
      "... ...\n",
      "9.276913e-06\n",
      "8.4145695e-06\n",
      "7.625091e-06\n",
      "6.904066e-06\n",
      "6.2513377e-06\n"
     ]
    }
   ],
   "source": [
    "# 设定损失函数\n",
    "crit = nn.MSELoss()\n",
    "# 设定优化器\n",
    "opt = nn.Adam(params=net.trainable_params())\n",
    "# 引入损失函数\n",
    "net_with_criterion = nn.WithLossCell(net, crit)\n",
    "# 自定义网络训练\n",
    "train_net = TrainOneStepCell(net_with_criterion, opt)\n",
    "\n",
    "# 获取训练过程数据\n",
    "\n",
    "for i in range(300):\n",
    "    for d in dataset.create_dict_iterator():\n",
    "        train_net(d[\"data\"], d[\"label\"])\n",
    "        print(net_with_criterion(d[\"data\"], d[\"label\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.3.0",
   "language": "python",
   "name": "mindspore-1.3.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}